'''
Author: zhanwei xu
Date: 2023-05-27 18:50:03
LastEditors: zhanwei xu
LastEditTime: 2023-06-14 17:14:07
Description: 

Copyright (c) 2023 by zhanwei xu, Tsinghua University, All Rights Reserved. 
'''

import json
from urllib.parse import quote

import httpx
import requests
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse

from .utils import async_request, get_api_keys, get_api_key, check_traffic
# Router that carries the drawing (image generation) endpoints of this module.
draw_router = APIRouter()
# API keys loaded once, at import time, from "api_keys.txt" via the project
# helper get_api_keys.
api_keys = get_api_keys("api_keys.txt")
# Position of the next key to use. forward_img_generate reassigns it from the
# value returned by get_api_key (it declares `global api_index` itself; a
# module-level `global` statement has no effect, so none is used here).
api_index = 0


# Downstream image-generation backend that this endpoint forwards to.
_IMG_BACKEND_URL = "http://107.173.221.112:9070/img_generate"
# requests.post previously had no timeout, while httpx defaults to 5 s, which
# may be too short. Use an explicit, generous limit instead.
# NOTE(review): 120 s is an assumption; tune it to the backend's real latency.
_IMG_BACKEND_TIMEOUT = 120.0


def _parse_sse_content(raw: str) -> str:
    """Concatenate the ``choices[0].delta.content`` fragments of an SSE body.

    ``raw`` is the complete decoded body of a server-sent-events stream whose
    ``data:`` lines carry JSON objects. Lines that do not decode to the
    expected shape (e.g. a ``[DONE]`` sentinel) are skipped.
    """
    fragments = []
    for line in raw.splitlines():
        line = line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            continue
        if not isinstance(event, dict):
            continue
        choices = event.get("choices")
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            continue
        delta = choices[0].get("delta")
        if not isinstance(delta, dict):
            continue
        text = delta.get("content")
        if isinstance(text, str) and text:
            # Preserve the original handling: turn literal "\n" / "\t"
            # sequences in the content into real newlines / tabs.
            fragments.append(text.replace('\\n', '\n').replace('\\t', '\t'))
    return "".join(fragments)


async def _translate_to_english(text, api_key) -> str:
    """Ask the chat API (via async_request) to translate *text* into English.

    Returns "" without making a request when *text* is empty or blank.
    Otherwise the model's reply would be used verbatim as the prompt.
    """
    if not text or not text.strip():
        return ""
    message = [{'role': 'user', 'content': "Translate it into English: " + text}]
    response = await async_request(message, api_key)
    # Join all chunks before decoding. A chunk boundary can split a multi-byte
    # UTF-8 character (decode error) or an SSE event (lost fragment).
    # NOTE(review): if async_request returns a streaming `requests.Response`,
    # iterating it here blocks the event loop; consider an async client.
    body = b"".join(response.iter_content(chunk_size=None))
    return _parse_sse_content(body.decode("utf-8", errors="replace"))


@draw_router.post("/img_generate")
async def forward_img_generate(request: Request):
    """Translate the prompts into English and forward the job to the backend.

    Expects a JSON body with ``prompt``, ``negative_prompt`` and
    ``model_name`` ("normal", "structure" or "anime"), plus optional
    ``original_image`` and ``mask_image`` (default "").

    Returns the backend's JSON reply. When ``check_traffic`` flags the client
    IP, it returns an error dict instead, with the message "only 100 uses per
    24 hours, please come back tomorrow".
    """
    global api_index

    client_ip = request.client.host
    if check_traffic(client_ip):
        return {'error': '每24小时只能使用100次，请明天再来'}

    # Read the data sent by the frontend.
    request_data = await request.json()
    prompt = request_data["prompt"]
    negative_prompt = request_data["negative_prompt"]
    model_name = request_data["model_name"]  # "normal", "structure", "anime"
    original_image = request_data.get('original_image', "")
    mask_image = request_data.get('mask_image', "")

    # Rotate keys for both translations. The negative prompt used to always
    # go through api_keys[0], bypassing the rotation.
    api_key, api_index = get_api_key(api_keys, api_index)
    translated_prompt = await _translate_to_english(prompt, api_key)
    api_key, api_index = get_api_key(api_keys, api_index)
    translated_negative_prompt = await _translate_to_english(negative_prompt, api_key)

    # Build the backend request from the translated prompts.
    backend_payload = {
        "prompt": translated_prompt,
        "negative_prompt": translated_negative_prompt,
        "model_name": model_name,
        "original_image": original_image,
        "mask_image": mask_image,
    }
    # Use an async client: a blocking requests.post here stalled the event loop.
    async with httpx.AsyncClient(timeout=_IMG_BACKEND_TIMEOUT) as client:
        response = await client.post(_IMG_BACKEND_URL, json=backend_payload)
    # Relay the backend's reply to the frontend.
    return response.json()

# Upstream headers that must not be copied onto the proxied response.
# Hop-by-hop headers (RFC 7230 §6.1) apply to a single connection only.
# httpx also decodes the body transparently, so the upstream Content-Encoding
# and Content-Length no longer describe the bytes that are re-sent.
_TASK_STATUS_EXCLUDED_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "content-encoding",
    "content-length",
})


@draw_router.get("/img_generate_task/{task_id}")
async def forward_img_generate_task_status(task_id: str):
    """Proxy a task-status lookup to the image backend.

    Relays the backend's body, status code and end-to-end headers.
    ``task_id`` comes from the client, so it is percent-encoded before being
    placed in the upstream URL. Characters such as "?" or "#" then cannot
    alter the request. Plain ids (letters, digits, "-", "_", ".", "~") pass
    through unchanged.
    """
    url = f"http://107.173.221.112:9070/img_generate_task/{quote(task_id, safe='')}"
    async with httpx.AsyncClient() as client:
        response = await client.get(url)

    # The body was fully read before the client closed, so aiter_bytes()
    # replays the buffered content.
    forwarded_headers = {
        name: value
        for name, value in response.headers.items()
        if name.lower() not in _TASK_STATUS_EXCLUDED_HEADERS
    }
    return StreamingResponse(
        response.aiter_bytes(),
        media_type=response.headers.get("content-type"),
        headers=forwarded_headers,
        status_code=response.status_code,
    )




